#!/usr/bin/env python3
# Copyright (c) 2025 Jason Lowe-Power
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met: redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer;
# redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution;
# neither the name of the copyright holders nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
This module implements a terminal client that connects to gem5's terminal server.
It's functionally equivalent to the C implementation in term.c, providing both
TCP/IP and Unix domain socket connectivity options.
"""

import argparse
import os
import select
import socket
import sys
import termios
import time
from typing import (
    Tuple,
    Union,
)


class TerminalClient:
    """
    A terminal client that connects to gem5's terminal server.

    This class handles:
    - Terminal settings (raw mode)
    - Network connections (TCP/IP and Unix domain sockets)
    - Bidirectional data transfer
    - Escape sequence handling (~.)
    """

    # Byte values used by the escape-sequence handling in _filter_escape().
    _ESCAPE_BYTE = ord("~")
    _QUIT_BYTE = ord(".")

    def __init__(self):
        # Original terminal settings, captured by setup_raw_term() so that
        # restore_term() can put them back.
        self.saved_term_settings = None
        # True when the previous stdin chunk ended with a pending "~".
        self.escape_mode = False
        # Binary file object for the session log, or None when not logging.
        self.logfile = None
        # Output flags (termios attribute index 1) as they were before
        # setup_raw_term() modified them. The attribute name is historical.
        self.cflag = None

    def setup_logging(self, logpath: str) -> None:
        """
        Start logging the session to ``logpath`` (opened for binary write).

        If a log file is already open it is closed once the new one has been
        opened successfully, so repeated calls do not leak file handles.

        Raises:
            OSError: If the file cannot be opened
        """
        new_log = open(logpath, "wb")
        if self.logfile is not None:
            self.logfile.close()
        self.logfile = new_log

    def close_logging(self) -> None:
        """Close the session log, if one is open."""
        if self.logfile is not None:
            self.logfile.close()
            self.logfile = None

    def log_data(self, data: bytes, direction: str) -> None:
        """
        Append one timestamped record to the session log.

        Does nothing when logging has not been set up.

        Args:
            data: The bytes that were transferred (logged as their repr)
            direction: Free-form label for the record, e.g. "send" or "recv"
        """
        if self.logfile is None:
            return
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        self.logfile.write(f"[{timestamp}] {direction}: {data!r}\n".encode())
        # Flush per record so the log is usable even if the process is
        # terminated without closing the file.
        self.logfile.flush()

    def setup_raw_term(self) -> None:
        """
        Configure the terminal attached to stdin for interactive passthrough.

        Input translation, software flow control, canonical mode, echo and
        signal generation are turned off so every key press is delivered
        to the client as-is. Output post-processing is left on with NL
        mapped to CR-NL.

        The settings in effect before the call are saved in
        ``saved_term_settings`` for restore_term().
        """
        fd = sys.stdin.fileno()

        # Keep a pristine copy for restoration and a second one to modify.
        self.saved_term_settings = termios.tcgetattr(fd)
        new_settings = termios.tcgetattr(fd)

        # Remember the output flags as they were before modification.
        self.cflag = new_settings[1]

        # Input flags: no 8th-bit stripping, no CR/NL translation or CR
        # dropping, no XON/XOFF flow control.
        new_settings[0] &= ~(
            termios.ISTRIP
            | termios.ICRNL
            | termios.IGNCR
            | termios.IXOFF
            | termios.IXON
        )

        # Keep output post-processing on and ensure ONLCR (map NL to CR-NL)
        new_settings[1] |= termios.OPOST | termios.ONLCR

        # Local flags: no signal keys, no line editing, no local echo.
        new_settings[3] &= ~(termios.ISIG | termios.ICANON | termios.ECHO)

        # read() returns as soon as at least one byte is available.
        new_settings[6][termios.VMIN] = 1
        new_settings[6][termios.VTIME] = 0

        termios.tcsetattr(fd, termios.TCSANOW, new_settings)

    def restore_term(self) -> None:
        """Restore terminal to the settings saved by setup_raw_term().

        Does nothing if setup_raw_term() never saved any settings.
        """
        if self.saved_term_settings is not None:
            termios.tcsetattr(
                sys.stdin.fileno(), termios.TCSANOW, self.saved_term_settings
            )

    def connect(self, host: str, port: Union[str, int]) -> socket.socket:
        """
        Create a connection to either a TCP/IP host or Unix domain socket.

        Args:
            host: Either a hostname or '--unix' for Unix domain sockets
            port: Port number for TCP/IP or path for Unix domain socket

        Returns:
            Connected socket object

        Raises:
            OSError: If the connection cannot be established
                (ConnectionError for TCP/IP)
        """
        # Both helpers expect a string; ``port`` may legitimately be an int.
        if host == "--unix":
            return self._connect_unix(str(port))
        return self._connect_inet(host, str(port))

    def _connect_inet(self, host: str, port: str) -> socket.socket:
        """
        Establish a TCP/IP connection to the specified host and port.

        Every address returned by name resolution is tried in order; the
        first one that accepts the connection wins.

        Args:
            host: Hostname to connect to
            port: Port number as string

        Returns:
            Connected socket object

        Raises:
            ConnectionError: If connection cannot be established; the last
                underlying OSError, if any, is attached as the cause
            OSError: If name resolution itself fails
        """
        last_error = None
        for family, socktype, proto, _, addr in socket.getaddrinfo(
            host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
        ):
            try:
                s = socket.socket(family, socktype, proto)
            except OSError as err:
                last_error = err
                continue
            try:
                s.connect(addr)
            except OSError as err:
                # Do not leak the descriptor of a failed attempt.
                s.close()
                last_error = err
                continue
            return s
        raise ConnectionError(
            f"Could not connect to {host}:{port}"
        ) from last_error

    def _connect_unix(self, path: str) -> socket.socket:
        """
        Connect to a Unix domain socket.

        Supports both regular and abstract socket addresses.
        Abstract sockets are specified with a leading '@'.

        Args:
            path: Path to Unix domain socket or @name for abstract sockets

        Returns:
            Connected socket object

        Raises:
            OSError: If the connection cannot be established
        """
        # An abstract address is the name prefixed with a NUL byte.
        addr = "\0" + path[1:] if path.startswith("@") else path
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            s.connect(addr)
        except OSError:
            # Do not leak the descriptor when the connection fails.
            s.close()
            raise
        return s

    def _filter_escape(self, data: bytes) -> Tuple[bytes, bool]:
        """
        Apply escape-sequence handling to one chunk read from stdin.

        A "~" is only treated as an escape when it is the first byte of a
        chunk. It is held back until the following byte is known:
        "~." requests quit, "~~" sends a single "~", and "~" followed by
        anything else sends both bytes. The remainder of the chunk is never
        discarded.

        Args:
            data: Bytes read from stdin

        Returns:
            (payload, quit_requested): the bytes to forward to the server,
            and whether the quit sequence was seen
        """
        if not data:
            return b"", False

        if not self.escape_mode:
            if data[0] != self._ESCAPE_BYTE:
                return data, False
            # Hold the "~" back and look at what follows it.
            self.escape_mode = True
            data = data[1:]
            if not data:
                return b"", False

        # A "~" is pending and at least one byte follows it.
        self.escape_mode = False
        if data[0] == self._QUIT_BYTE:
            return b"", True
        if data[0] != self._ESCAPE_BYTE:
            # Not an escape after all: emit the held "~" as well.
            return b"~" + data, False
        # "~~": the second "~" stands for the single literal one.
        return data, False

    def read_write_loop(self, sock: socket.socket) -> None:
        """
        Main I/O loop handling bidirectional data transfer.

        Monitors both the socket and stdin for data, handles escape sequences,
        and manages connection status. Returns when the server closes the
        connection, when the quit sequence (~.) is entered, or on any
        OSError from the socket or the terminal.

        Features:
        - Escape sequence handling (~. to quit)
        - Binary-safe data transfer
        - Half-close of the socket when stdin reaches end-of-file

        Args:
            sock: Connected socket to read/write from/to
        """
        watched = [sock, sys.stdin]
        while True:
            try:
                readable, _, _ = select.select(watched, [], [], 0.1)

                for source in readable:
                    if source is sock:
                        # Socket data - pass through directly
                        data = sock.recv(4096)
                        if not data:
                            return
                        sys.stdout.buffer.write(data)
                        sys.stdout.buffer.flush()
                        self.log_data(data, "recv")
                        continue

                    # Terminal input
                    data = os.read(sys.stdin.fileno(), 4096)
                    if not data:
                        # End of input: tell the server we are done sending
                        # and stop polling stdin, which would otherwise be
                        # reported readable on every iteration.
                        sock.shutdown(socket.SHUT_WR)
                        watched = [sock]
                        continue

                    payload, quit_requested = self._filter_escape(data)
                    if quit_requested:
                        print("quit!")
                        return
                    if payload:
                        # sendall() retries until everything is written;
                        # send() may write only part of the buffer.
                        sock.sendall(payload)
                        self.log_data(payload, "send")
            except OSError:
                return

    def update_window_size(self, sock: socket.socket) -> None:
        """
        Send the current terminal size to the server.

        The size of the terminal attached to stdin is queried with the
        TIOCGWINSZ ioctl and sent as four 0xff bytes followed by the row and
        column counts as big-endian unsigned 16-bit integers.

        Args:
            sock: Connected socket to send the size message on

        Raises:
            OSError: If the ioctl or the send fails
        """
        import fcntl
        import struct

        # Get terminal size
        winsize = struct.pack("HHHH", 0, 0, 0, 0)
        winsize = fcntl.ioctl(sys.stdin.fileno(), termios.TIOCGWINSZ, winsize)
        rows, cols, _, _ = struct.unpack("HHHH", winsize)

        # Special sequence for window size, then the size itself. sendall()
        # guarantees the message is not truncated.
        sock.sendall(b"\xff\xff\xff\xff" + struct.pack("!HH", rows, cols))

def main():
    """
    Main entry point for the terminal client.

    Handles:
    - Command line argument parsing
    - Terminal setup/restoration
    - Connection establishment
    - Error handling and cleanup

    Exits with an error message if stdin is not a terminal or if any
    OSError occurs while setting up or running the session. The terminal
    settings are restored and the socket and log file are closed on every
    exit path.
    """
    # Setup argument parser
    parser = argparse.ArgumentParser(
        description="Connect to a gem5 terminal server over TCP/IP or Unix domain socket"
    )
    parser.add_argument(
        "--unix",
        dest="unix",
        action="store_true",
        help="Connect to unix domain socket",
    )
    parser.add_argument(
        "host", nargs="?", default="localhost", help="Hostname to connect to"
    )
    parser.add_argument("port", help="Port number or socket path")
    parser.add_argument("--log", help="Log terminal session to file")
    # NOTE(review): --buffer-size, --timeout and --no-escape are accepted for
    # command-line compatibility but nothing in this function consumes them.
    parser.add_argument(
        "--buffer-size",
        type=int,
        default=4096,
        help="Socket buffer size (default: 4096)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=1.0,
        help="Socket timeout in seconds (default: 1.0)",
    )
    parser.add_argument(
        "--no-escape",
        action="store_true",
        help="Disable escape sequence handling",
    )
    args = parser.parse_args()

    if not sys.stdin.isatty():
        sys.exit("Not connected to a terminal")

    client = TerminalClient()
    sock = None
    try:
        client.setup_raw_term()
        if args.log:
            client.setup_logging(args.log)
        # The client treats the literal host "--unix" as a request for a
        # Unix domain socket, with the port argument used as the path.
        host = "--unix" if args.unix else args.host
        sock = client.connect(host, args.port)
        client.read_write_loop(sock)
    except OSError as e:
        # ConnectionError is a subclass of OSError, so this covers both.
        sys.exit(f"Connection error: {e}")
    finally:
        # Release the socket and the log file, but never let a failure
        # there prevent the terminal from being restored.
        try:
            if sock is not None:
                sock.close()
            if client.logfile is not None:
                client.logfile.close()
        finally:
            client.restore_term()

# Run the client only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
